#!/usr/bin/env python3
"""
Phase 6: Country Prediction & Scorecards
Tables 17-19: Scorecards, counterfactuals, EU accession event study.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

CCA_DIR = PROJECT_DIR / "cca_tipping"
PROCESSED_DIR = CCA_DIR.joinpath("data", "processed")
TABLE_DIR = CCA_DIR.joinpath("output", "tables")
# Created up front so every table write further down can assume it exists.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ── Country Sets ───────────────────────────────────────────────────────────
# ISO3 groupings.  Countries in CCA_ALL but not in CCA_NON_COMMODITY are the
# commodity group.
CCA_ALL = set("ARM AZE BLR GEO KAZ KGZ MDA MNG RUS TJK TKM UKR UZB".split())
CCA_NON_COMMODITY = set("ARM BLR GEO KGZ MDA MNG TJK UKR".split())
BALTIC = set("EST LVA LTU".split())
CEE = set("ALB BGR BIH HRV CZE HUN MKD MNE POL ROU SRB SVK SVN".split())

# EU accession year per ISO3 code, listed by enlargement wave.
EU_ACCESSION = {
    iso: wave_year
    for wave_year, members in (
        (2004, ('EST', 'LVA', 'LTU', 'CZE', 'HUN', 'POL', 'SVK', 'SVN')),
        (2007, ('BGR', 'ROU')),
        (2013, ('HRV',)),
    )
    for iso in members
}

# ── Helpers ────────────────────────────────────────────────────────────────
def star(p):
    """Return the conventional significance stars for p-value ``p``.

    ``"***"`` below 0.01, ``"**"`` below 0.05, ``"*"`` below 0.10, otherwise
    the empty string.  A NaN p-value compares False everywhere and so yields "".
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""

def run_gls(df, dep_var, indep_vars, *, min_obs=50):
    """Fit ``PanelGLS`` of *dep_var* on *indep_vars* using complete cases only.

    Rows with a missing value in the dependent variable or in any regressor
    are dropped before fitting.  ``iso3`` and ``year`` are passed to
    ``PanelGLS.fit`` as panel identifiers; how they are used (fixed effects,
    weighting, clustering) is defined in ``model.PanelGLS``, not here.

    Args:
        df: DataFrame containing ``iso3``, ``year`` and the named columns.
        dep_var: Name of the dependent-variable column.
        indep_vars: Iterable (list or tuple) of regressor column names.
        min_obs: Minimum number of complete rows required to estimate the
            model (keyword-only, default 50).

    Returns:
        ``None`` if fewer than *min_obs* complete rows remain.  Otherwise a
        dict with ``r_squared``, ``n_obs``, ``n_countries`` and the
        per-regressor dicts ``coefficients``, ``std_errors`` and ``p_values``,
        keyed by regressor name.
    """
    # Normalise to a list: a tuple would break both the list concatenation
    # and DataFrame column selection (pandas reads a tuple as a single key).
    indep_vars = list(indep_vars)
    comp = df.dropna(subset=[dep_var, *indep_vars])
    if len(comp) < min_obs:
        return None
    gls = PanelGLS()
    gls.fit(
        comp[dep_var].values,
        comp[indep_vars].values,
        comp['iso3'].values,
        comp['year'].values,
    )
    return {
        'r_squared': gls.r_squared,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'coefficients': dict(zip(indep_vars, gls.beta)),
        'std_errors': dict(zip(indep_vars, gls.se)),
        'p_values': dict(zip(indep_vars, gls.pvalues)),
    }

# ── Load Data ──────────────────────────────────────────────────────────────
print("=" * 70, "PHASE 6: PREDICTION & SCORECARDS", "=" * 70, sep="\n")

# Estimation sample: observed current account, years 1986 through 2024 inclusive.
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
est = panel.loc[panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)].copy()

# Demographic factors Z_1..Z_3 followed by the baseline macro controls.
demo_vars = [f'Z_{k}' for k in range(1, 4)]
baseline_controls = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
base_vars = [*demo_vars, *baseline_controls]

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 17: COUNTRY SCORECARDS
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 17: COUNTRY SCORECARDS")
print("=" * 70)

# Governance interaction model: ca_gdp on base_vars, demeaned governance
# (gov_dm) and each Z_k × gov_dm.  Its Z_1 and Z_1_x_gov coefficients drive
# the Z₁ contribution reported in the scorecard.
gov_sample = est.dropna(subset=base_vars + ['ca_gdp', 'governance_composite']).copy()
# Sample mean used for demeaning; reused below so each country's recent
# governance is demeaned exactly like the regressor.
gov_mean = gov_sample['governance_composite'].mean()
gov_sample['gov_dm'] = gov_sample['governance_composite'] - gov_mean
for z in demo_vars:
    gov_sample[f'{z}_x_gov'] = gov_sample[z] * gov_sample['gov_dm']
gov_int_vars = [f'{z}_x_gov' for z in demo_vars]
full_vars = base_vars + ['gov_dm'] + gov_int_vars

r_gov = run_gls(gov_sample, 'ca_gdp', full_vars)
if r_gov is None:
    # NaN coefficients propagate, so every Z₁ contribution below becomes NaN.
    print("  Warning: governance interaction model has insufficient obs; Z₁ contributions are NaN")
    beta_z1 = beta_z1_gov = np.nan
else:
    beta_z1 = r_gov['coefficients']['Z_1']
    beta_z1_gov = r_gov['coefficients']['Z_1_x_gov']

# Read the Hansen threshold from the Phase 5 output (the row with minimum SSR).
try:
    hansen_df = pd.read_csv(TABLE_DIR / "table16_hansen_threshold.csv")
    best_threshold = hansen_df.loc[hansen_df['ssr_total'].idxmin(), 'threshold']
except (OSError, KeyError, ValueError) as err:
    # The expected failures are:
    #   - the file is missing or unreadable (OSError);
    #   - a column is missing, or idxmin returned NaN on an all-NA column (KeyError);
    #   - the file is empty or unparseable, or the SSR column is empty (ValueError,
    #     which includes pandas' EmptyDataError).
    # Anything else is a genuine bug and should surface.
    best_threshold = gov_sample['governance_composite'].median()
    print(f"  Warning: could not read Hansen threshold ({type(err).__name__}: {err})")
    print(f"  Warning: using median governance as threshold ({best_threshold:.2f})")

# Baltic mean governance (all Baltic country-years in gov_sample) as the
# "lifecycle" benchmark.
baltic_gov = gov_sample[gov_sample['iso3'].isin(BALTIC)]['governance_composite'].mean()
print(f"  Baltic mean governance: {baltic_gov:.3f}")
print(f"  Hansen threshold: {best_threshold:.3f}")

# Scorecards for every transition country present in gov_sample.
transition_countries = sorted((CCA_ALL | BALTIC | CEE) & set(gov_sample['iso3'].unique()))
scorecard_rows = []

for iso in transition_countries:
    # Non-empty by construction: iso was drawn from gov_sample['iso3'].
    sub = gov_sample[gov_sample['iso3'] == iso]

    # Average over the country's most recent 5 years (max year − 4 onward).
    recent = sub[sub['year'] >= sub['year'].max() - 4]
    mean_gov = recent['governance_composite'].mean()
    mean_z1 = recent['Z_1'].mean()

    # Z₁ contribution: Z₁·β_Z1 + Z₁·(gov − gov_mean)·β_Z1×gov, evaluated at
    # the recent-period means of Z₁ and governance.
    z1_total = mean_z1 * beta_z1 + mean_z1 * (mean_gov - gov_mean) * beta_z1_gov

    if iso in CCA_NON_COMMODITY:
        group = 'CCA_non_commodity'
    elif iso in CCA_ALL:
        group = 'CCA_commodity'
    elif iso in BALTIC:
        group = 'Baltic'
    else:
        group = 'CEE'

    scorecard_rows.append({
        'iso3': iso,
        'group': group,
        'mean_governance': mean_gov,
        'mean_Z_1': mean_z1,
        'mean_ca_gdp': recent['ca_gdp'].mean(),
        'mean_kaopen': recent['kaopen'].mean(),
        'Z1_contribution': z1_total,
        # Distance to the lifecycle regime (positive = below benchmark).
        'gov_gap_to_baltic': baltic_gov - mean_gov,
        'gov_gap_to_threshold': best_threshold - mean_gov,
        'above_threshold': mean_gov > best_threshold,
    })

scorecard_df = pd.DataFrame(scorecard_rows)
# An empty frame has no columns, so sorting by name would raise KeyError.
if not scorecard_df.empty:
    scorecard_df = scorecard_df.sort_values('mean_governance')
scorecard_df.to_csv(TABLE_DIR / "table17_scorecards.csv", index=False)

print(f"\n  {'ISO':<5s} {'Group':<18s} {'Gov':>6s} {'Gap→Baltic':>10s} {'Gap→τ':>7s} {'Z₁':>6s} {'Z₁→CA':>7s} {'CA/GDP':>7s}")
print("  " + "-" * 75)
for _, row in scorecard_df.iterrows():
    above = "✓" if row['above_threshold'] else " "
    print(f"  {row['iso3']:<5s} {row['group']:<18s} {row['mean_governance']:6.2f} "
          f"{row['gov_gap_to_baltic']:10.2f} {row['gov_gap_to_threshold']:7.2f}{above} "
          f"{row['mean_Z_1']:6.4f} {row['Z1_contribution']:7.2f} {row['mean_ca_gdp']:7.2f}")

# Markdown.  The two header facts are separate paragraphs; adjacent lines would
# otherwise be merged into one line when rendered.
md = ["# Table 17: Transition Country Scorecards\n"]
md.append(f"Baltic governance benchmark: {baltic_gov:.3f}\n")
md.append(f"Hansen threshold: {best_threshold:.3f}\n")
md.append("| Country | Group | Governance | Gap to Baltic | Gap to τ | Z₁ | Z₁→CA | CA/GDP |")
md.append("|---|---|---|---|---|---|---|---|")
for _, row in scorecard_df.iterrows():
    md.append(f"| {row['iso3']} | {row['group']} | {row['mean_governance']:.2f} | "
              f"{row['gov_gap_to_baltic']:.2f} | {row['gov_gap_to_threshold']:.2f} | "
              f"{row['mean_Z_1']:.4f} | {row['Z1_contribution']:.2f} | {row['mean_ca_gdp']:.2f} |")
(TABLE_DIR / "table17_scorecards.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 18: COUNTERFACTUALS — CCA AT BALTIC GOVERNANCE
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 18: COUNTERFACTUALS")
print("=" * 70)

if r_gov:
    cf_rows = []
    beta_z1_gov = r_gov['coefficients']['Z_1_x_gov']
    gov_mean = gov_sample['governance_composite'].mean()
    # Demeaned Baltic benchmark is the same for every country; compute once.
    baltic_gov_dm = baltic_gov - gov_mean

    for iso in sorted(CCA_ALL & set(gov_sample['iso3'].unique())):
        sub = gov_sample[gov_sample['iso3'] == iso]
        # Most recent 5 years for this country (max year − 4 onward).
        recent = sub[sub['year'] >= sub['year'].max() - 4]
        if len(recent) == 0:
            continue

        actual_gov = recent['governance_composite'].mean()
        actual_z1 = recent['Z_1'].mean()
        actual_ca = recent['ca_gdp'].mean()

        # Counterfactual: same demographics, Baltic governance.
        actual_gov_dm = actual_gov - gov_mean

        # Change in the Z₁ contribution only through the interaction term:
        # Z₁ · (baltic_gov_dm − actual_gov_dm) · β_Z1×gov.
        delta_ca = actual_z1 * (baltic_gov_dm - actual_gov_dm) * beta_z1_gov

        cf_rows.append({
            'iso3': iso,
            'actual_governance': actual_gov,
            'actual_ca_gdp': actual_ca,
            'baltic_governance': baltic_gov,
            'delta_ca_gdp': delta_ca,
            'counterfactual_ca_gdp': actual_ca + delta_ca,
            'Z_1': actual_z1,
        })

    cf_df = pd.DataFrame(cf_rows)
    cf_df.to_csv(TABLE_DIR / "table18_counterfactuals.csv", index=False)

    print(f"\n  {'ISO':<5s} {'Actual Gov':>10s} {'Baltic Gov':>10s} {'Actual CA':>10s} {'ΔCA':>8s} {'CF CA':>8s}")
    print("  " + "-" * 55)
    for _, row in cf_df.iterrows():
        print(f"  {row['iso3']:<5s} {row['actual_governance']:10.2f} {row['baltic_governance']:10.2f} "
              f"{row['actual_ca_gdp']:10.2f} {row['delta_ca_gdp']:8.2f} {row['counterfactual_ca_gdp']:8.2f}")

    md = ["# Table 18: Counterfactual — CCA Countries at Baltic Governance\n"]
    md.append(f"Interaction coefficient (Z₁×gov): {beta_z1_gov:.3f}{star(r_gov['p_values']['Z_1_x_gov'])}\n")
    md.append("| Country | Actual Gov | Baltic Gov | Actual CA/GDP | ΔCA/GDP | Counterfactual CA/GDP |")
    md.append("|---|---|---|---|---|---|")
    for _, row in cf_df.iterrows():
        md.append(f"| {row['iso3']} | {row['actual_governance']:.2f} | {row['baltic_governance']:.2f} | "
                  f"{row['actual_ca_gdp']:.2f} | {row['delta_ca_gdp']:.2f} | {row['counterfactual_ca_gdp']:.2f} |")
    (TABLE_DIR / "table18_counterfactuals.md").write_text("\n".join(md))
else:
    # Say so explicitly: otherwise the table is silently skipped and any
    # table18_* files left by an earlier run would look current.
    print("  Governance interaction model unavailable (insufficient obs) — Table 18 skipped")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 19: EU ACCESSION EVENT STUDY
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 19: EU ACCESSION EVENT STUDY")
print("=" * 70)

# Countries that joined the EU — does the Z₁ coefficient change around accession?
accession_countries = set(EU_ACCESSION.keys()) & set(est['iso3'].unique())
print(f"  EU accession countries in sample: {sorted(accession_countries)}")

# post_accession: 1.0 from the accession year onward, 0.0 before, and NaN for
# countries without an entry in EU_ACCESSION.
est_eu = est.copy()
est_eu['eu_accession_year'] = est_eu['iso3'].map(EU_ACCESSION)
est_eu['post_accession'] = (est_eu['year'] >= est_eu['eu_accession_year']).astype(float)
est_eu.loc[est_eu['eu_accession_year'].isna(), 'post_accession'] = np.nan

# A) Split sample.  Dropping rows with NaN post_accession already leaves only
# accession countries.  The baseline model is then estimated separately on
# their pre- and post-accession years.
eu_sample = est_eu.dropna(subset=base_vars + ['ca_gdp', 'post_accession']).copy()
eu_treated = eu_sample[eu_sample['iso3'].isin(accession_countries)]

eu_pre = eu_treated[eu_treated['post_accession'] == 0]
eu_post = eu_treated[eu_treated['post_accession'] == 1]

r_pre = run_gls(eu_pre, 'ca_gdp', base_vars)
r_post = run_gls(eu_post, 'ca_gdp', base_vars)

eu_rows = []
for name, r in [("Pre-accession", r_pre), ("Post-accession", r_post)]:
    if r is None:
        print(f"  {name}: insufficient obs")
        continue
    row = {'period': name, 'n_countries': r['n_countries'], 'n_obs': r['n_obs'],
           'r_squared': r['r_squared']}
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_se'] = r['std_errors'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    eu_rows.append(row)
    print(f"  {name}: Z₁={row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} (p={row['Z_1_pval']:.4f}), "
          f"N_c={r['n_countries']}, N={r['n_obs']}")

# B) Pooled interaction model on the full sample (accession and all other countries).
eu_int_sample = est_eu.copy()
eu_int_sample['post_acc'] = eu_int_sample['post_accession'].fillna(0)
eu_int_sample['is_accession'] = eu_int_sample['iso3'].isin(accession_countries).astype(float)
eu_int_sample['Z_1_x_is_acc'] = eu_int_sample['Z_1'] * eu_int_sample['is_accession']
eu_int_sample['Z_1_x_acc_post'] = (
    eu_int_sample['Z_1'] * eu_int_sample['is_accession'] * eu_int_sample['post_acc']
)

# Triple interaction Z₁ × is_accession × post.  Non-accession countries have no
# accession date, so post_acc is 0 for them and already equals
# is_accession × post_acc.  A separate Z₁ × post_acc regressor would be an exact
# copy of Z_1_x_acc_post, making the design matrix rank-deficient, so it is not
# included.  post_acc itself acts as the treated × post dummy.
# NOTE(review): if PanelGLS absorbs country fixed effects, the time-invariant
# is_accession dummy is also collinear with them — confirm in model.PanelGLS.
triple_vars = base_vars + ['is_accession', 'post_acc', 'Z_1_x_is_acc', 'Z_1_x_acc_post']
r_triple = run_gls(eu_int_sample, 'ca_gdp', triple_vars)
if r_triple:
    z1_acc_post = r_triple['coefficients']['Z_1_x_acc_post']
    z1_acc_post_se = r_triple['std_errors']['Z_1_x_acc_post']
    z1_acc_post_p = r_triple['p_values']['Z_1_x_acc_post']
    print(f"\n  Triple-diff (Z₁ × accession × post): {z1_acc_post:.2f}{star(z1_acc_post_p)} (p={z1_acc_post_p:.4f})")
    eu_rows.append({'period': 'Triple-diff interaction', 'Z_1_x_acc_post': z1_acc_post,
                    'Z_1_x_acc_post_se': z1_acc_post_se,
                    'Z_1_x_acc_post_pval': z1_acc_post_p,
                    'n_countries': r_triple['n_countries'], 'n_obs': r_triple['n_obs'],
                    'r_squared': r_triple['r_squared']})
else:
    print("\n  Triple-diff: insufficient obs")

pd.DataFrame(eu_rows).to_csv(TABLE_DIR / "table19_eu_accession.csv", index=False)

# Markdown.  Split-sample rows report Z_1 itself; the triple-diff row reports
# the Z₁ × accession × post interaction in the same columns.
md = ["# Table 19: EU Accession Event Study\n"]
md.append("| Period | N_c | N_obs | Z₁ | SE | p-val | R² |")
md.append("|---|---|---|---|---|---|---|")
for r in eu_rows:
    z1 = r.get('Z_1_coef', r.get('Z_1_x_acc_post', np.nan))
    se = r.get('Z_1_se', r.get('Z_1_x_acc_post_se', np.nan))
    p = r.get('Z_1_pval', r.get('Z_1_x_acc_post_pval', np.nan))
    se_str = f"{se:.2f}" if pd.notna(se) else "—"
    p_str = f"{p:.4f}" if pd.notna(p) else "—"
    star_str = star(p) if pd.notna(p) else ""
    md.append(f"| {r['period']} | {r['n_countries']} | {r['n_obs']} | "
              f"{z1:.2f}{star_str} | {se_str} | "
              f"{p_str} | {r['r_squared']:.4f} |")
(TABLE_DIR / "table19_eu_accession.md").write_text("\n".join(md))

# Closing summary: where the tables went, then a completion marker.
print(f"\nAll Phase 6 tables saved to {TABLE_DIR}", "Phase 6 complete.", sep="\n")
